import numpy as np
import pandas as pd
from netZooPy.ligress.bonobo import Bonobo
from bonobo_competitors_scaled import lioness
from bonobo_competitors_scaled import spcc
from bonobo_competitors_scaled import sweet


def simulate_bonobo(sim_mean, sim_cov, ngene=100, nsample=100, nsample_ind=100,
                    sweet_output_dir='./SWEET_main/simulation_sweet/outputs/'):
    """Benchmark BONOBO, sparse BONOBO, LIONESS, SPCC and SWEET on simulated data.

    For each of ``nsample`` individuals, ``nsample_ind`` observations are drawn
    from a multivariate normal with mean ``sim_mean["x"]`` and covariance
    ``sim_cov``. The per-individual mean vector becomes one column of the
    expression matrix handed to every method. The per-individual empirical
    correlation matrix is the ground truth each method's network is scored
    against.

    Args:
        sim_mean: Table with a column ``"x"`` holding the per-gene means
            (accessed via ``sim_mean["x"].to_numpy()``).
        sim_cov: Covariance matrix passed to ``np.random.multivariate_normal``.
        ngene: Kept for backward compatibility only. The gene count actually
            used is taken from ``sim_mean`` so the SWEET edge indexing always
            matches the simulated data.
        nsample: Number of simulated individuals (columns of the expression
            matrix).
        nsample_ind: Number of draws per individual used to estimate that
            individual's mean and correlation.
        sweet_output_dir: Directory holding SWEET's per-sample ``<i>.txt``
            output files.

    Returns:
        Tuple of five lists ``(mse_bonobo, mse_bonobo_sparse, mse_lioness,
        mse_spcc, mse_sweet)``. Each holds one mean-squared error per
        individual.
    """
    import os

    gene_means = sim_mean["x"].to_numpy()
    # Derive the gene count from the data rather than trusting ``ngene``.
    # A mismatch would make the SWEET upper-triangle comparison index the
    # wrong edges.
    n_genes = gene_means.shape[0]
    upper = np.triu_indices(n_genes, k=1)

    # Simulate each individual's expression; keep its mean and correlation matrix.
    ind_mean = []
    ind_cor = []
    for _ in range(nsample):
        expr_sim = np.random.multivariate_normal(gene_means, sim_cov, nsample_ind).T
        ind_mean.append(np.mean(expr_sim, 1))
        ind_cor.append(np.corrcoef(expr_sim))

    # genes x individuals
    expression = pd.DataFrame(np.array(ind_mean).transpose())

    bonobo_obj = Bonobo(expression)
    bonobo_obj.run_bonobo(keep_in_memory=True, output_fmt='.txt')
    ind_bonobo = bonobo_obj.bonobos

    bonobo_obj = Bonobo(expression)
    bonobo_obj.run_bonobo(keep_in_memory=True, output_fmt='.txt', sparsify=True)
    ind_bonobo_sparse = bonobo_obj.bonobos

    # The return value is unused here. SWEET results are read back from
    # ``sweet_output_dir`` below.
    # NOTE(review): presumably this call (or an external SWEET run) writes
    # those files — confirm.
    sweet(expression.T)

    ind_lioness = []
    ind_spcc = []
    for i in range(nsample):
        ind_lioness.append(lioness(expression, i))
        ind_spcc.append(spcc(expression, i))

    # Compute the per-individual MSE of each method against the ground truth.
    mse_bonobo = []
    mse_bonobo_sparse = []
    mse_lioness = []
    mse_spcc = []
    mse_sweet = []
    for i in range(nsample):
        truth = ind_cor[i]
        mse_bonobo.append(_mse(truth, ind_bonobo[i]))
        mse_bonobo_sparse.append(_mse(truth, ind_bonobo_sparse[i]))
        mse_lioness.append(_mse(truth, ind_lioness[i]))
        mse_spcc.append(_mse(truth, ind_spcc[i]))

        # SWEET reports one score per edge. This assumes rows follow the same
        # row-major upper-triangle order as ``np.triu_indices`` — TODO confirm.
        sweet_net = pd.read_csv(os.path.join(sweet_output_dir, str(i) + '.txt'), sep="\t")
        sweet_edges = sweet_net['raw_edge_score'].to_numpy()
        mse_sweet.append(_mse(sweet_edges, truth[upper]))

    return mse_bonobo, mse_bonobo_sparse, mse_lioness, mse_spcc, mse_sweet


def _mse(expected, estimate):
    """Return the mean squared difference between two same-shaped arrays/frames."""
    # np.array() flattens a possible DataFrame result so np.mean averages all cells.
    return np.mean(np.array((expected - estimate) ** 2))

